#!/usr/bin/env python

# Copyright 2014 The Souper Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fileinput, sys, copy
import collections

# Register name -> LLVM type string such as "i32".  Holds both the names
# read from the input and the names assigned during renaming.
widths = dict()

# Dataflow facts for input variables, keyed by the renamed argument
# register ("%x<N>").
knownzeros = dict()   # integer mask to AND into the value
knownones = dict()    # integer mask to OR into the value
Negative = dict()     # integer mask with only the sign bit set
nonNegative = dict()  # integer mask with every bit but the sign bit set
powerTwo = dict()     # flag (1) for variables marked powerOfTwo
nonZero = dict()      # flag (1) for variables marked nonZero
signBits = dict()     # number of sign bits, kept as a string
lower = dict()        # lower range bound, kept as a string
upper = dict()        # upper range bound, kept as a string

# Renamed register -> list of path-condition entries, in first-seen order.
pcmap = collections.OrderedDict()
# Original variable register -> "%x<N>" argument name.
argmap = dict()
# The candidate entry recorded by rewriteRegs from a cand/infer line.
cand = None
# True between an "infer" line and its matching "result" line.
needsResult = False
# (intrinsic name + width) -> LLVM declaration text.
intrdecl = dict()
# Phi blocks whose predecessor labels have already been emitted.
printedblocks = []
# Labels emitted so far; phis take their predecessors from the tail.
labels = ["entry"]
# Counters for generated names: %tmp<cnt>, cont<cnt2>, out<cnt3>/foo<cnt3>.
cnt = 0
cnt2 = 0
cnt3 = 0
# Groups of instructions that follow a path condition in the input.
independent = []
# Set to 1 by a pc line so that the next instruction starts a new group.
flagPC = 0
# State used while printing range metadata and helper declarations.
MDRangePtr = 0
metadataStr = ""
functionDeclStr = ""
functionDeclCount = 0
truncCount = 0

# souper IR -> LLVM inst map
# Only opcodes whose LLVM spelling differs from the Souper name are listed;
# translateInstToLLVM emits any other non-intrinsic opcode under its own
# name.
instmap = {
    "addnsw"    : "add nsw",
    "addnuw"    : "add nuw",
    "addnw"     : "add nsw nuw",
    "subnsw"    : "sub nsw",
    "subnuw"    : "sub nuw",
    "subnw"     : "sub nsw nuw",
    "mulnsw"    : "mul nsw",
    "mulnuw"    : "mul nuw",
    "mulnw"     : "mul nsw nuw",
    "udivexact" : "udiv exact",
    "sdivexact" : "sdiv exact",
    "shlnsw"    : "shl nsw",
    "shlnuw"    : "shl nuw",
    "shlnw"     : "shl nsw nuw",
    "lshrexact" : "lshr exact",
    "ashrexact" : "ashr exact",
    "eq"        : "icmp eq",
    "ne"        : "icmp ne",
    "ult"       : "icmp ult",
    "slt"       : "icmp slt",
    "ule"       : "icmp ule",
    "sle"       : "icmp sle",
}

# special inst handling
# Opcodes that may carry more than two operands; rewriteRegs unrolls them
# into a chain of two-operand instructions.
nestedinsts = ["add", "mul", "and", "or", "xor"]
# Comparisons: the printed type is that of the first operand, not of the
# result.
boolinsts = ["eq", "ne", "ugt", "uge", "ult", "ule", "sgt", "sge", "slt", "sle"]
# Width-changing casts, printed as "<op> <srcty> <src> to <dstty>".
towidthinsts = ["sext", "trunc", "zext"];
# Opcodes emitted as calls to @llvm.<name>.<width>.
intrinsts = ["bswap", "ctpop", "cttz", "ctlz", "fshl", "fshr",
             "sadd.sat", "uadd.sat", "ssub.sat", "usub.sat",
             "sadd.with.overflow", "uadd.with.overflow", "ssub.with.overflow",
             "usub.with.overflow", "smul.with.overflow", "umul.with.overflow"]
# Souper directives that do not define a value; getRegMap and reorder skip
# them.
specialinsts = ["block", "pc", "cand", "infer", "result"]

def parseConst(c):
    """Split a Souper constant operand of the form ``<value>:<type>``.

    Returns ``[value, type]`` with both parts as strings, e.g.
    ``"5:i32"`` -> ``["5", "i32"]``.

    Raises AssertionError when `c` has no ':' or when the type part does
    not start with 'i'.
    """
    assert ":" in c, "wrong constant format %s" % c
    fields = c.split(':')
    constant = fields[0]
    width = fields[1]
    # startswith() also rejects an empty type ("5:") with the descriptive
    # message; indexing width[0] raised a bare IndexError in that case.
    assert width.startswith("i"), "wrong type format %s" % width
    return [constant, width]

def parseOps(ops):
    """Convert raw operand tokens into [name, type] pairs.

    Surrounding whitespace and commas are removed from each token.
    Register operands (leading '%') take their type from `widths`, the
    "(hasExternalUses)" marker is dropped, and everything else is parsed
    as a ``<value>:<type>`` constant.
    """
    parsed = []
    for token in ops:
        token = token.strip().strip(",")
        if token[0] == "%":
            parsed.append([token, widths[token]])
        elif token != "(hasExternalUses)":
            parsed.append(parseConst(token))
    return parsed

def parseInst(i):
    """Parse one whitespace-tokenized Souper line.

    `i` is the list of tokens of a single line.  Returns
    ``[reg, width, inst, ops]``:

    * ``infer %r [fact]``    -> reg, its known width, "infer", raw extra token
    * ``result <op>``        -> "result" with one parsed operand
    * ``%b = block <n>``     -> reg, "block", ops holding the raw count token
    * ``cand %r <op>``       -> reg, "cand", one parsed operand
    * ``pc %r <const>``      -> reg, "pc", one parsed constant
    * ``%r:iN = <inst> ...`` -> reg, width, inst, parsed operands ("var"
      keeps its raw fact tokens)

    Registers defined by a line are recorded in the module-level `widths`.

    Raises AssertionError on malformed lines and KeyError when an infer
    line names a register that has not been defined.
    """
    reg, width, inst, ops = None, None, None, []
    assert len(i) >= 2, "wrong inst format %s" % i
    if i[0] == "infer":
        # `and`, not `&`: the bitwise operator binds tighter than the
        # comparisons, which let over-long infer lines pass the check.
        assert len(i) >= 2 and len(i) <= 3, "wrong infer inst length %d, %s" % (len(i), i)
        inst = i[0]
        reg = i[1]
        width = widths[reg]
        # optional third token (demanded bits) is kept as a raw string;
        # the slice is empty for the two-token form
        ops = i[2:]
    elif i[0] == "result":
        assert len(i) == 2, "wrong result inst length %d, %s" % (len(i), i)
        inst = i[0]
        ops = parseOps([i[1]])
    # guard the index so a two-token line reaches a descriptive assertion
    # below instead of raising IndexError here
    elif len(i) > 2 and i[2] == "block":
        assert len(i) == 4, "wrong block inst length %d, %s" % (len(i), i)
        reg = i[0]
        inst = i[2]
        if reg not in widths:
            widths[reg] = None
        ops = [i[3]]
    elif i[0] == "cand":
        assert len(i) == 3, "wrong cand inst length %d, %s" % (len(i), i)
        inst = i[0]
        reg = i[1]
        ops = parseOps([i[2]])
    elif i[0] == "pc":
        assert len(i) == 3, "wrong instruction length %d, %s" % (len(i), i)
        inst = i[0]
        reg = i[1]
        ops = [parseConst(i[2])]
    elif i[0][0] == "%":
        assert len(i) >= 3, "wrong instruction length %d, %s" % (len(i), i)
        tmp = i[0].split(":")
        assert len(tmp) == 2, "wrong reg length %d, %s" % (len(tmp), tmp)
        assert i[1] == "=", "expecting '=', got %s" % i[1]
        reg = tmp[0]
        width = tmp[1]
        inst = i[2]
        if "overflow" in i[2]:
            # the declared width is one bit wider than the value type used
            # for the intrinsic, so drop one bit
            width = "i" + str(int(width[1:])-1)
        widths[reg] = width
        if inst == "var":
            # dataflow-fact tokens are decoded later by propagateArgNames
            ops = i[3:]
        else:
            ops = parseOps(i[3:])
    else:
        assert 0, "unknown LHS: %s" % i[0]
    return [reg, width, inst, ops]


# Numbers the %sw<N> compare registers and nl<N> labels that
# translateInstToLLVM emits when selecting a phi predecessor.
newlabelcounter = 0

def translateInstToLLVM(i):
    """Render one rewritten instruction as LLVM IR text.

    `i` is a [reg, width, inst, ops] entry whose operands are
    [name, type] pairs.  Returns the text without a trailing newline.
    For the first phi of a block the text also contains the branches and
    labels that act as the phi's predecessors; emitting those updates the
    module-level cnt3, newlabelcounter, labels and printedblocks.
    """
    reg, width, inst, ops = i
    s = ""
    if inst == "phi":
        # ops[0] is the block reference, ops[1:] are the incoming values
        block = ops[0][0]
        philabel = "philabel" + str(block[1:])
        if not block in printedblocks:
            global cnt3
            t = cnt3
            blockcount = len(ops[1:])
            if blockcount == 1:
                s += "  br label %" + philabel + "\n"
            if blockcount >= 2:
                # chain of compares on %_phiinput that picks one of the
                # foo<N> predecessor labels
                global newlabelcounter
                # NOTE: this loop reuses the name `i`; the parameter has
                # already been unpacked above
                for i in range(blockcount):
                    if i == blockcount - 1:
                        s += "  br label %foo" + str(t) + "\n"
                    else:
                        s += "  %sw" +str(newlabelcounter) + " = icmp eq i64 %_phiinput, " + str(newlabelcounter) + "\n"
                        s += "  br i1 %sw" +str(newlabelcounter) +" , label %nl"+ str(newlabelcounter) +", label %foo" + str(t) + "\n"
                        s += "nl"+ str(newlabelcounter) +":\n"
                        newlabelcounter += 1
                        t += 1

            printedblocks.append(block)

            # one foo<N> label per incoming value, each jumping to the phi
            for k in range(len(ops[1:])):
                # don't print the first label of the func
                label = "foo" + str(cnt3)
                cnt3 += 1
                labels.append(label)
                s += label + ":\n"
                s += "  br label %" + philabel + "\n"
            s += philabel + ":\n"
        s += "  " + reg + " = phi " + width + " "
        # grab the necessary number of predecessor labels
        # NOTE(review): for a later phi of an already printed block this
        # takes the most recently appended labels -- presumably correct
        # because phis of one block are emitted together; confirm
        blocklabels = labels[-len(ops[1:]):]
        for k, op in enumerate(ops[1:]):
            s += "[" + op[0] + ", %" + blocklabels[k] + "]"
            if k < len(ops[1:])-1:
                s += ", "
    else:
        s = "  " + reg + " = "
        if inst in instmap:
            s += instmap[inst]
        elif not inst in intrinsts:
            s += inst
        # special cases
        if inst in boolinsts:
            # comparisons are typed by their first operand
            assert len(ops) >= 1, "must have at least one op %s" % inst
            s += " " + ops[0][1]
        elif inst in towidthinsts:
            assert len(ops) == 1, "must have exactly one op %s" % inst
            s += " " + ops[0][1] + " " + ops[0][0] + " to " + width
        elif inst == "extractvalue":
            # extracting i1s only
            s += " {" + ops[0][1] + ", i1}"
        elif inst in intrinsts:
            if "overflow" in inst:
                s += "call {" + width + ", i1} @llvm." + inst + "." + width \
                              + "(" + ops[0][1] + " " + ops[0][0] + ", " \
                              + ops[1][1] + " " + ops[1][0] + ")"
            else:
                # only the first operand is passed to the call
                # NOTE(review): intrinsics in intrinsts that take more
                # than one argument would need the remaining operands
                # here -- confirm
                s += "call " + width + " @llvm." + inst + "." + width + "(" \
                             + ops[0][1] + " " + ops[0][0] + ")"
        elif inst == "select":
            # select prints a type in front of every operand below
            pass
        else:
            s += " " + width
        # write ops
        if inst not in towidthinsts and inst not in intrinsts:
            for op in ops:
                if inst == "select":
                    s += " " + op[1] + " " + op[0] + ","
                else:
                    s += " " + op[0] + ","
    # drop the comma left behind by the last operand
    return s.strip(",")

def readOpt():
    """Read the input via fileinput and tokenize it.

    Blank lines and lines whose first non-blank character is ';' are
    skipped.  Returns one list of whitespace-separated tokens per
    remaining line; asserts that at least one such line exists.
    """
    tokenized = []
    for raw in fileinput.input():
        text = raw.strip()
        if not text or text[0] == ";":
            continue
        tokenized.append(raw.split())
    assert len(tokenized), "empty file"
    return tokenized

def parseInsts(lines):
    """Run parseInst over every tokenized line, keeping the input order."""
    return [parseInst(tokens) for tokens in lines]

def propagateArgNames(insts):
    """Turn "var" instructions into function arguments named %x<N>.

    Each var is recorded in `argmap` (original register -> %x<N>) and
    `widths`, and its dataflow-fact tokens are handed to the matching
    parse* helper.  Operands of all other instructions that refer to a
    variable are renamed in place.

    Returns the instructions other than the var declarations, in their
    original order.
    """
    res = []
    counter = 1
    for i in insts:
        reg, width, inst, ops = i
        if inst == "var":
            newreg = "%x" + str(counter)
            argmap[reg] = newreg
            widths[newreg] = width
            widthValue = width.strip("i")
            if len(ops):
                # seven kinds of facts are recognized below; the message
                # used to say "six" while the check allowed seven
                assert len(ops) <= 7, "at most seven dataflow facts expected, got %d" % len(ops)
                for j in ops:
                    if "knownBits" in j:
                        parseKnownBits(j, newreg)
                    elif "negative" in j:
                        parseNegative(j, newreg, int(widthValue))
                    elif "nonNegative" in j:
                        parseNonNegative(j, newreg, int(widthValue))
                    elif "powerOfTwo" in j:
                        parsePowerTwo(j, newreg, int(widthValue))
                    elif "nonZero" in j:
                        parsenonZero(j, newreg, int(widthValue))
                    elif "signBits" in j:
                        parsesignBits(j, newreg, int(widthValue))
                    elif "range" in j:
                        parseRange(j, newreg, int(widthValue))
            counter += 1
            # var lines do not produce an instruction of their own
            continue
        for k, op in enumerate(ops):
            if op[0] in argmap:
                ops[k][0] = argmap[op[0]]
        res.append(i)
    return res

# topological sort
def sort_direct_acyclic_graph(edge_list):
    """Return the nodes of a directed acyclic graph in topological order.

    `edge_list` is an iterable of (node_from, node_to) pairs; it is not
    modified.  In the returned list every node appears after all nodes
    that have an edge pointing to it.  The relative order of unrelated
    nodes is unspecified.  An empty edge list yields an empty list.

    Raises IndexError when the edges contain a cycle.
    """
    # the working set is consumed below, so operate on a private copy
    edge_set = set(tuple(edge) for edge in edge_list)
    if not edge_set:
        return []

    # node_list will contain the ordered nodes
    node_list = []

    # source_set is the set of nodes with no incoming edges
    source_set = (set(edge[0] for edge in edge_set)
                  - set(edge[1] for edge in edge_set))

    while source_set:
        # pop node_from off source_set and insert it in node_list
        node_from = source_set.pop()
        node_list.append(node_from)

        # remove every edge that leaves node_from
        for edge in [e for e in edge_set if e[0] == node_from]:
            edge_set.discard(edge)
            node_to = edge[1]
            # node_to becomes a source once nothing points to it any more
            if not any(e[1] == node_to for e in edge_set):
                source_set.add(node_to)

    if edge_set:
        # edges left over: not a directed acyclic graph
        raise IndexError
    return node_list


def reorder(insts):
    """Order instructions so that definitions come before their uses.

    `insts` is a list of [reg, width, inst, ops] entries.  Returns
    block definitions first, then the value-producing instructions in
    dependency order (all phis of a block kept together), then the pc
    entries and finally the cand/infer/result entries.

    The list passed in is consumed: it is reordered and then emptied
    while the dependency edges are collected.
    """
    early = []
    pc = []
    late = []
    # keep the original sequence; `insts` itself is drained further down
    origin_insts = []
    origin_insts += insts

    for i in insts:
        if i[2] == 'block':
            early.append(i)
        if i[2] == 'pc':
            pc.append(i)
        if i[2] == 'cand' or i[2] == 'infer' or i[2] == 'result':
            late.append(i)


    # block register -> phi instructions that belong to that block
    blocks = dict()
    regmap = getRegMap(insts)
    # scan for blocks and phis
    for k in range(len(insts)):
        i = insts[k]
        reg, width, inst, ops = i
        if inst == "block":
            assert not reg in blocks, "block %s already parsed" % reg
            blocks[reg] = []
            # move block def to front
            insts.insert(0, insts.pop(k))
        # skip pc/cand/infer/result insts
        elif inst in specialinsts:
            pass
        elif inst == "phi":
            block = ops[0][0]
            assert block in blocks, "unknown block %s" % block
            blocks[block].append(i)

    # `queue` aliases `insts`, so popping from it empties the argument
    queue = insts
    # (user register, register it depends on)
    edges = set()
    visited = []

    # collect dependency edges
    while len(queue):
        i = queue.pop()
        reg, width, inst, ops = i
        # defining instructions of the operands of `i`
        new_ops = []
        if i in visited:
            continue
        visited.append(i)

        # put the defining instruction of every operand in new_ops
        if inst == "phi":
            # a phi stands for every phi of its block: gather the
            # operands of all of them, skipping the block reference
            block = ops[0][0]
            for phi in blocks[block]:
                for k, op in enumerate(phi[3]):
                    if k == 0:
                        continue
                    if op[0] in regmap:
                        new_ops.append(regmap[op[0]])
        else:
            for k, op in enumerate(ops):
                if op[0] in regmap:
                    new_ops.append(regmap[op[0]])

        for k, op in enumerate(new_ops):
            # `op` is an instruction here, so op[0] is the register it
            # defines
            dep = op[0]
            # process each operand
            if dep[0] == "%" and dep in regmap:
                op_inst = regmap[dep]
                if op_inst[2] == "phi":
                    # depending on one phi means depending on every phi
                    # of the same block
                    dep_inst_block = op_inst[3][0][0]
                    for dep_phi in blocks[dep_inst_block]:
                        edges.add((reg, dep_phi[0]))
                        queue.append(regmap[dep_phi[0]])
                else:
                    edges.add((reg, regmap[dep][0]))
                    queue.append(regmap[dep])

    unsorted_dag = []
    for edge in edges:
        unsorted_dag.append(edge)

    # registers ordered users-first; empty when there are no dependencies
    dag = []
    if len(edges) != 0:
        dag = sort_direct_acyclic_graph(unsorted_dag)
    res = []

    # instructions that take part in no dependency edge keep their
    # original relative order
    for inst in origin_insts:
        if inst[2] in specialinsts:
            continue
        if inst[0] in dag:
            continue
        res.append(inst)

    for v in dag:
        if regmap[v] in res:
            continue
        reg, width, inst, ops = regmap[v]
        if inst == "phi":
            block = ops[0][0]
            for phi in blocks[block]:
                res.append(phi)
        else:
            res.append(regmap[v])

    # edges run from user to dependency, so flip to get definitions first
    res.reverse()

    res = early + res + pc + late

    return res


def getRegMap(insts):
    """Map each defined register to the first instruction that defines it.

    Entries whose opcode is listed in `specialinsts` are ignored.
    """
    mapping = dict()
    for entry in insts:
        name, _width, kind, _ops = entry
        if kind not in specialinsts:
            # setdefault keeps the earliest definition of a register
            mapping.setdefault(name, entry)
    return mapping

def getInstDeps(regmap, i, blocks):
    """Collect the instructions that `i` transitively depends on.

    `regmap` maps registers to their defining instruction and `blocks`
    maps a block register to the phis of that block.  A phi is always
    treated together with every other phi of its block.  Returns the
    dependencies in the order they were discovered.
    """
    deps = []
    _reg, _width, kind, operands = i
    # a phi starts the walk from all phis of its block
    if kind == "phi":
        pending = list(blocks[operands[0][0]])
    else:
        pending = [i]
    seen = set()
    while pending:
        cur_reg, _cur_width, cur_kind, cur_ops = pending.pop()
        if cur_reg in seen:
            continue
        seen.add(cur_reg)

        assert cur_kind not in specialinsts

        for operand in cur_ops:
            name = operand[0]
            if not (name[0] == "%" and name in regmap):
                continue
            producer = regmap[name]
            _p_reg, _p_width, prod_kind, prod_ops = producer
            # a phi producer drags in the whole phi group of its block
            if prod_kind == "phi":
                group = blocks[prod_ops[0][0]]
            else:
                group = [producer]
            for member in group:
                if member[0] not in seen:
                    deps.append(member)
                    pending.append(member)

    return deps

def rewriteRegs(insts):
    """Give every value a fresh numbered name and sort entries by role.

    Walks the reordered instructions and
      * drops block definitions,
      * records the cand/infer entry in the global `cand`,
      * groups pc entries in `pcmap` under their renamed register,
      * renames each remaining definition to "%<counter>", rewriting its
        operands and unrolling nested instructions with more than two
        operands.

    Returns the renamed instructions seen before the first pc.  Once a
    pc has been seen, non-nested instructions are appended to the groups
    in the global `independent` instead of the returned list.
    """
    counter = 0
    res = []
    # original register -> new name, seeded with the argument names
    newregs = copy.deepcopy(argmap)
    global cand
    global needsResult
    for i in insts:
        reg, width, inst, ops = i
        if inst == "block":
            continue
        elif inst == "cand":
            assert not cand, "cand must be empty"
            assert reg in newregs, "unknown reg %s for %s" % (reg, i)
            rewriteOps(inst, newregs, ops, i)
            cand = [newregs[reg], width, inst, ops]
            continue
        elif inst == "infer":
            assert not cand, "cand must be empty"
            cand = [newregs[reg], width, inst, ops]
            needsResult = True
            # stops the walk: entries after the infer line are not
            # processed
            break
        elif inst == "result":
            assert needsResult, "infer inst not found"
            assert cand, "cand must not be empty"
            rewriteOps(inst, newregs, ops, i)
            cand[3] = ops
            needsResult = False
            continue
        elif inst == "pc":
            assert reg in newregs, "unknown reg %s for %s" % (reg, i)
            newreg = newregs[reg]
            pc = [newreg, width, inst, ops]
            if newreg in pcmap:
                pcmap[newreg].append(pc)
            else:
                pcmap[newreg] = [pc]
            # the next ordinary instruction starts a new `independent`
            # group
            global flagPC
            flagPC = 1
            continue
        # give regs new names
        newreg = "%" + str(counter)
        counter += 1
        newregs[reg] = newreg
        widths[newreg] = width
        rewriteOps(inst, newregs, ops, i)
        # special case: unroll nested insts
        if inst in nestedinsts and len(ops) > 2:
            # first pair of operands, then fold in one operand at a time
            res.append([newreg, width, inst, [ops[0], ops[1]]])
            for op in ops[2:]:
                newreg = "%" + str(counter)
                counter += 1
                widths[newreg] = width
                lastreg = res[len(res)-1][0]
                res.append([newreg, width, inst, [[lastreg, width], op]])
            # the original register now names the end of the chain
            newregs[reg] = newreg
        else:
            newinst = [newreg, width, inst, ops]
            # store intrinsic func decls
            if inst in intrinsts:
                if (inst+width) not in intrdecl:
                    if "overflow" in inst:
                        decl = "declare {" + width + ", i1} @llvm." + inst \
                                           + "." + width + "(" + width + ", " \
                                           + width + ") #" + str(len(intrdecl))
                    else:
                        decl = "declare " + width + " @llvm." + inst + "." \
                                          + width + "(" + width + ") #" \
                                          + str(len(intrdecl))
                    intrdecl[inst+width] = decl
            # Separate the instructions not dependent on branch condition
            if flagPC==1 and len(pcmap)>0:
                # first instruction after a pc opens a new group
                independent.append([newinst])
                flagPC = 0
            elif flagPC==0 and len(pcmap)>0:
                independent[-1].append(newinst)
            else:
                res.append(newinst)
    return res

def rewriteOps(inst, newregs, ops, i):
    """Rename register operands in place using `newregs`.

    The first operand of a phi (its block reference) is left alone, as
    are constants and argument registers (names containing "%x").  `i`
    is only used in the assertion message raised for an unknown
    register.
    """
    is_phi = inst == "phi"
    for pos, operand in enumerate(ops):
        if is_phi and pos == 0:
            continue
        name = operand[0]
        if name[0] != "%" or "%x" in name:
            continue
        assert name in newregs, "unknown reg %s, reordering bug in %s" % (name, i)
        ops[pos][0] = newregs[name]

def parseKnownBits(s,reg):
    """Decode a "(knownBits=...)" fact for argument register `reg`.

    The value is read as a bit string, most significant bit first.
    Every '0' is a known-zero bit and every '1' a known-one bit; other
    characters are unknown bits.  Stores an AND mask in `knownzeros`
    and/or an OR mask in `knownones` and prints a comment describing the
    masks.
    """
    assert "(" in s and ")" in s and "knownBits=" in s
    # take what follows the key; str.strip("knownBits=") would strip a
    # *set of characters* from both ends rather than the prefix
    bits = s.strip("(").strip(")").split("knownBits=", 1)[1]
    zeromask = 0
    onemask = 0
    fullmask = 0
    for pos, c in enumerate(reversed(bits)):
        fullmask |= 1 << pos
        if c == '0':
            zeromask |= 1 << pos
        if c == '1':
            onemask |= 1 << pos
    print("; for " + str(len(bits)) + " bit value " + bits)
    if zeromask != 0:
        # AND mask: every bit set except the known-zero ones
        knownzeros[reg] = zeromask ^ fullmask
        print(";   val = val & " + str(zeromask ^ fullmask))
    else:
        print(";   no zero mask needed")
    if onemask != 0:
        knownones[reg] = onemask
        print(";   val = val | " + str(onemask))
    else:
        print(";   no one mask needed")
    print("")

def parseNegative(s,reg, width):
    """Decode a "(negative)" fact for argument register `reg`.

    Stores a mask with only the sign bit of a `width`-bit value set in
    `Negative` and prints a comment describing it.  Any other token is
    ignored.
    """
    assert "(" in s and ")" in s
    s = s.strip("(").strip(")")
    if s != "negative":
        return
    # a single set bit is never zero, so the former "no negative mask
    # needed" branch could not be reached and has been removed
    NegativeMask = 1<<(width-1)
    print("; for " + str(width) + " bit value " + s)
    Negative[reg] = NegativeMask
    print(";   val = val | " + str(NegativeMask))
    print("")

def parseNonNegative(s,reg, width):
    """Decode a "(nonNegative)" fact for argument register `reg`.

    Stores a mask covering every bit below the sign bit of a
    `width`-bit value in `nonNegative`, unless that mask is zero, and
    prints a comment describing the outcome.  Any other token is
    ignored.
    """
    assert "(" in s and ")" in s
    s = s.strip("(").strip(")")
    if s != "nonNegative":
        return
    # bits 0 .. width-2, i.e. everything but the sign bit
    mask = sum(1 << pos for pos in range(width - 1))
    print("; for " + str(width) + " bit value " + s)
    if mask == 0:
        print(";   no nonNegative mask needed")
    else:
        nonNegative[reg] = mask
        print(";   val = val & " + str(mask))
    print("")

def parsePowerTwo(s,reg, width):
    """Decode a "(powerOfTwo)" fact for argument register `reg`.

    Flags the register in `powerTwo` and prints a comment; any other
    token is ignored.
    """
    assert "(" in s and ")" in s
    fact = s.strip("(").strip(")")
    if fact != "powerOfTwo":
        return
    print("; for " + str(width) + " bit value " + fact)
    powerTwo[reg] = 1
    print("; " + str(reg) + " = 1 << %x")
    print("")

def parsenonZero(s,reg, width):
    """Decode a "(nonZero)" fact for argument register `reg`.

    Flags the register in `nonZero` and prints a comment; any other
    token is ignored.
    """
    assert "(" in s and ")" in s
    fact = s.strip("(").strip(")")
    if fact != "nonZero":
        return
    print("; for " + str(width) + " bit value " + fact)
    nonZero[reg] = 1
    print("; " + str(reg) + " != 0")
    print("")

def parsesignBits(s,reg, width):
    """Decode a "(signBits=N)" fact for argument register `reg`.

    Stores N, as a string, in `signBits` and prints a comment.  `width`
    is accepted for symmetry with the other fact parsers and is not
    used.
    """
    assert "(" in s and ")" in s and "signBits=" in s
    # take what follows the key; str.strip("signBits=") would strip a
    # *set of characters* from both ends rather than the prefix
    count = s.strip("(").strip(")").split("signBits=", 1)[1]
    print("; for signbits = " + count)
    signBits[reg] = count
    print("; " + str(reg) + " signbits")
    print("")

def parseRange(s,reg, width):
    """Decode a "(range=[lo,hi))" fact for argument register `reg`.

    Stores the two bounds, as strings, in `lower` and `upper` and prints
    a comment.  `width` is accepted for symmetry with the other fact
    parsers and is not used.

    Raises AssertionError when the token is malformed or a bound is
    missing.
    """
    assert "(" in s and ")" in s and "range=[" in s
    # take what follows the key; str.strip("range=[") would strip a
    # *set of characters* from both ends rather than the prefix
    text = s.strip("(").strip(")").split("range=[", 1)[1]
    print("; for range = [" + text + ")")
    bounds = text.split(',')
    assert len(bounds) >= 2, "wrong range format %s" % s
    lower[reg] = bounds[0]
    upper[reg] = bounds[1]
    print("; " + str(reg) + " range")
    print("")

def printFuncHeader():
    """Print everything up to and including the entry block prologue.

    Emits the target datalayout, the external @glob_* globals for every
    argument, the intrinsic declarations, the "define ... @foo(" line
    with its arguments and the entry label, followed by the
    instructions that impose the recorded dataflow facts on the
    arguments.

    For every fact the function parameter is replaced by a fresh
    temporary (%t<k>_<n>) and an instruction is prepended that computes
    the previously used name from that temporary.  Facts handled later
    therefore end up earlier in the emitted prologue.

    Appends range metadata to `metadataStr` and helper declarations to
    `functionDeclStr` for printing after the function.
    """
    # argument register -> name currently used for the function parameter
    regtmps = {}
    for i, reg in enumerate(argmap.values()):
        regtmps[reg] = reg
    # prologue text, built back to front
    masks = ""
    global MDRangePtr
    global metadataStr
    global functionDeclStr
    global functionDeclCount
    for i, reg in enumerate(argmap.values()):
        # Each branch below follows the same pattern: `old` is the name
        # that must be defined, `name` is the new temporary that becomes
        # the parameter, and the result is also stored to an external
        # global.
        if reg in knownzeros:
            regnum = reg[1:]
            w = widths[reg]
            old = regtmps[reg]
            name = "%t1_" + regnum
            regtmps[reg] = name
            masks = "  " + old + " = and " + w + " " + name + ", " + str(knownzeros[reg]) + "\n" +\
            "  store " + w + " " + old + ", " + w + "* @glob_" + regnum + "_and\n" + masks
            regtmps[reg] = name
        if reg in knownones:
            regnum = reg[1:]
            w = widths[reg]
            old = regtmps[reg]
            name = "%t2_" + regnum
            regtmps[reg] = name
            masks = "  " + old + " = or " + w + " " + name + ", " + str(knownones[reg]) + "\n" +\
            "  store " + w + " " + old + ", " + w + "* @glob_" + regnum + "_or\n" + masks
        if reg in Negative:
            regnum = reg[1:]
            w = widths[reg]
            old = regtmps[reg]
            name = "%t3_" + regnum
            regtmps[reg] = name
            masks = "  " + old + " = or " + w + " " + name + ", " + str(Negative[reg]) + "\n" +\
            "  store " + w + " " + old + ", " + w + "* @glob_" + regnum + "_negative\n" + masks
        if reg in nonNegative:
            regnum = reg[1:]
            w = widths[reg]
            old = regtmps[reg]
            name = "%t4_" + regnum
            regtmps[reg] = name
            masks = "  " + old + " = and " + w + " " + name + ", " + str(nonNegative[reg]) + "\n" +\
            "  store " + w + " " + old + ", " + w + "* @glob_" + regnum + "_nonNegative\n" + masks
        if reg in powerTwo:
            # power of two: the value is 1 shifted left by the parameter
            regnum = reg[1:]
            w = widths[reg]
            old = regtmps[reg]
            name = "%t5_" + regnum
            regtmps[reg] = name
            masks = "  " + old + " = shl " + w + " " + "1" + ", " + name + "\n" +\
            "  store " + w + " " + old + ", " + w + "* @glob_" + regnum + "_powerOfTwo\n" + masks
        if reg in signBits:
            # sign bits: truncate to the reduced width, then sign-extend
            # back to the full width
            regnum = reg[1:]
            w = widths[reg]
            old = regtmps[reg]
            name = "%t6_" + regnum
            regtmps[reg] = name
            global truncCount
            global truncStr
            sign = "%s_" + str(truncCount)
            org_width = int(w.strip("i"))
            red_width = org_width + 1 - int(signBits[reg])
            masks = "  " + sign + " = trunc " + w + " " + name + " to " + "i" + str(red_width) + "\n" +\
            "  " + old + " = sext " + "i" + str(red_width) + " " + sign + " to " + w + "\n" +\
            "  store " + w + " " + old + ", " + w + "* @glob_" + regnum + "_signBits\n" + masks
            truncCount += 1
        if reg in lower or reg in upper:
            # range: spill the parameter to a stack slot, pass the slot
            # to an external function and reload it with !range metadata
            regnum = reg[1:]
            w = widths[reg]
            old = regtmps[reg]
            name = "%t8_" + regnum
            regtmps[reg] = name
            ptr = "%ptr_" + str(MDRangePtr)
            masks = "  " + ptr + " = alloca " + w + "\n" +\
            "  store " + w + " " + name + ", " + w + "* " + ptr + "\n" +\
            "  call void @fun_" + str(functionDeclCount) + "(" + w + "* " + ptr + ")\n" +\
            "  " + old + " = load " + w + ", " + w + "* " + ptr + ", !range !" + str(MDRangePtr) + "\n" +\
            "  store " + w + " " + old + ", " + w + "* @glob_" + regnum + "_range\n" + masks
            metadataStr = metadataStr + "!" + str(MDRangePtr) + " = " + "!{ " + w + " " + str(lower[reg]) + ", " + w + " " + str(upper[reg]) + " }\n"
            MDRangePtr += 1
            functionDeclStr += "declare void @fun_" + str(functionDeclCount) + "(" + w + "* " + ptr + ")\n"
            functionDeclCount += 1
        if reg in nonZero:
            # non-zero: same spill/reload sequence with the range
            # metadata written as { 1, 0 }
            regnum = reg[1:]
            w = widths[reg]
            old = regtmps[reg]
            name = "%t7_" + regnum
            regtmps[reg] = name
            ptr = "%ptr_" + str(MDRangePtr)
            masks = "  " + ptr + " = alloca " + w + "\n" +\
            "  store " + w + " " + name + ", " + w + "* " + ptr + "\n" +\
            "  call void @fun_" + str(functionDeclCount) + "(" + w + "* " + ptr + ")\n" +\
            "  " + old + " = load " + w + ", " + w + "* " + ptr + ", !range !" + str(MDRangePtr) + "\n" +\
            "  store " + w + " " + old + ", " + w + "* @glob_" + regnum + "_nonZero\n" + masks
            metadataStr = metadataStr + "!" + str(MDRangePtr) + " = " + "!{ " + w + " 1" + ", " + w + " 0"+ " }\n"
            MDRangePtr += 1
            functionDeclStr += "declare void @fun_" + str(functionDeclCount) + "(" + w + "* " + ptr + ")\n"
            functionDeclCount += 1

    # global declarations: dummy store, intrinsics
    print "target datalayout = \"e-m:e-i64:64-f80:128-n8:16:32:64-S128\"\n"
    # every argument gets all eight globals, whether or not the matching
    # fact was given
    for i, reg in enumerate(argmap.values()):
        print "@glob_" + reg[1:] + "_and = external global " + widths[reg]
        print "@glob_" + reg[1:] + "_or = external global " + widths[reg]
        print "@glob_" + reg[1:] + "_negative = external global " + widths[reg]
        print "@glob_" + reg[1:] + "_nonNegative = external global " + widths[reg]
        print "@glob_" + reg[1:] + "_powerOfTwo = external global " + widths[reg]
        print "@glob_" + reg[1:] + "_nonZero = external global " + widths[reg]
        print "@glob_" + reg[1:] + "_signBits = external global " + widths[reg]
        print "@glob_" + reg[1:] + "_range = external global " + widths[reg]
    for decl in intrdecl.values():
        print decl
    # the trailing comma keeps the argument list on the same line
    print "\ndefine " + cand[1] + " @foo(",
    # print args
    for i, reg in enumerate(argmap.values()):
        sys.stdout.write(widths[reg] + " " + regtmps[reg])
        sys.stdout.write(", ")
    # extra argument compared against by the phi predecessor selection
    print "i64 %_phiinput) {"
    # first label
    print labels[0] + ":"
    print masks

def printFuncBody(insts):
    """Print the LLVM text of each instruction, one per print call."""
    for entry in insts:
        rendered = translateInstToLLVM(entry)
        print(rendered)

def printPCs():
    """Print the branches that enforce the recorded path conditions.

    Each path condition compares its register with the expected
    constant and branches to a fresh cont<N> label on equality, or to an
    out<N> label holding `unreachable` otherwise.  After the last
    condition of a register, the instruction group paired with that
    register (by position in `independent`) is printed.
    """
    global cnt
    global cnt2
    global cnt3
    # pair the registers in pcmap with the instruction groups, in order
    followers = dict(zip(pcmap.keys(), independent))
    for reg, pcs in pcmap.items():
        last = len(pcs) - 1
        for k, pc in enumerate(pcs):
            pieces = [
                "  %tmp" + str(cnt) + " = icmp eq " + widths[pc[0]] + " "
                + pc[3][0][0] + ", " + pc[0] + "\n",
                "  br i1 %tmp" + str(cnt) + ", label %cont" + str(cnt2)
                + ", label %out" + str(cnt3) + "\n",
                "out" + str(cnt3) + ":" + "\n",
                "  unreachable\n",
                "cont" + str(cnt2) + ":",
            ]
            # print adds the newline itself after the final condition
            if k < last:
                pieces.append("\n")
            cnt += 1
            cnt2 += 1
            cnt3 += 1
            print("".join(pieces))
            if k == last and reg in followers:
                for entry in followers[reg]:
                    print(translateInstToLLVM(entry))

def printFuncFooter():
    """Print a function tail that compares the candidate with its value.

    Emits the %cand comparison, the "return" and "dead" blocks, an "out"
    block when path conditions were recorded, and the closing brace.
    """
    assert cand, "there must be a candidate"
    tail = [
        "  %cand = icmp eq " + cand[3][0][1] + " "
        + cand[3][0][0] + ", " + cand[0],
        "  br i1 %cand, label %return, label %dead",
        "return:",
        "  %dummy1w = atomicrmw add i32* @dummy1, i32 1 monotonic",
        "  ret void",
        # dummy global variable modifications
        "dead:",
        "  %dummy2w = atomicrmw add i32* @dummy2, i32 1 monotonic",
        "  ret void",
    ]
    if len(pcmap):
        tail.extend([
            "out:",
            "  %dummy3w = atomicrmw add i32* @dummy3, i32 1 monotonic",
            "  ret void",
        ])
    tail.append("}")
    print("\n".join(tail))

def printLHSFuncFooter():
    """Print the return of the candidate value and the closing brace.

    Reads the register and type from the global `cand`.
    """
    # Both former branches printed this same line, and the loop over
    # cand[3] would have repeated the terminator once per entry; a block
    # needs exactly one `ret`.
    print("  ret " + cand[1] + " " + cand[0])
    print("}")

# Driver: read and tokenize the input, parse it, name the arguments, put
# definitions before uses and rename the remaining registers.
lines = readOpt()
insts = parseInsts(lines)
insts = propagateArgNames(insts)
insts = reorder(insts)
insts = rewriteRegs(insts)

# rewriteRegs records the cand/infer entry in the global `cand`
assert cand, "cand inst not found"

# Emit the function, then the metadata and helper declarations collected
# while printing the header.
printFuncHeader()
printFuncBody(insts)
printPCs()
printLHSFuncFooter()
print metadataStr
print functionDeclStr
